package com.liusu.deeplearning4j;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Command-line entry point that trains the network built by {@code MnistModel.buildModel()} on
 * the training iterator from {@code MnistLoader} and prints evaluation statistics on the test
 * iterator.
 *
 * <p>Usage: {@code TrainModel [epochs]}. The optional first argument overrides the number of
 * training epochs (default 10).
 */
public class TrainModel {
    /** Number of passes over the training data when no epoch count is given on the command line. */
    private static final int DEFAULT_EPOCHS = 10;

    /**
     * Trains the model for the requested number of epochs, then evaluates it on the test data and
     * prints the resulting statistics to standard output.
     *
     * @param args optional; {@code args[0]} is the number of training epochs (positive integer)
     * @throws IllegalArgumentException if {@code args[0]} is present but not a positive integer
     * @throws Exception propagated from data loading, model construction, or training
     */
    public static void main(String[] args) throws Exception {
        int numEpochs = parseEpochs(args);

        DataSetIterator mnistTrain = MnistLoader.getMnistTrainData();
        DataSetIterator mnistTest = MnistLoader.getMnistTestData();

        MultiLayerNetwork model = MnistModel.buildModel();

        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            // Rewind explicitly so each epoch starts from the beginning of the training data,
            // instead of depending on fit() to reset an already-exhausted iterator.
            if (mnistTrain.resetSupported()) {
                mnistTrain.reset();
            }
            model.fit(mnistTrain);
            // Report 1-based progress so the first finished pass reads "1/N", not "0".
            System.out.println("Completed epoch " + epoch + "/" + numEpochs);
        }

        Evaluation eval = model.evaluate(mnistTest);
        System.out.println(eval.stats());
    }

    /**
     * Reads the epoch count from the command line.
     *
     * @param args the program arguments; may be {@code null} or empty
     * @return {@code args[0]} parsed as an int, or {@link #DEFAULT_EPOCHS} when no argument is given
     * @throws IllegalArgumentException if {@code args[0]} is not a positive integer
     */
    private static int parseEpochs(String[] args) {
        if (args == null || args.length == 0) {
            return DEFAULT_EPOCHS;
        }
        int epochs;
        try {
            epochs = Integer.parseInt(args[0].trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("epochs must be an integer, got: " + args[0], e);
        }
        if (epochs <= 0) {
            throw new IllegalArgumentException("epochs must be positive, got: " + epochs);
        }
        return epochs;
    }
}
